
package com.jstarcraft.ai.jsat.linear;

import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;

/**
 * This class is used to create an implicit representation of the degree 2
 * polynomial of an input vector, with an implicit bias term added so that the
 * original vector values are present in the implicit vector. This means no
 * extra memory will be allocated, and all values accessed will be re-computed
 * as needed. This works with sparse vectors, and works best with algorithms
 * that iterate over the nonzero values once. <br>
 * <br>
 * Any change in the base vector will change the values in this vector. Because
 * changing one value in the base affects multiple values in this one, altering
 * this vector directly is not allowed. <br>
 * <br>
 * If the base vector has {@code N} non zero values, then this vec will have
 * O(N<sup>2</sup>) non zero values. (N+2)(N+1)/2 non zero values to be exact.
 *
 * @author Edward Raff
 */
public class Poly2Vec extends Vec {

    private static final long serialVersionUID = -5653680966558726340L;

    /** The vector being expanded. It is read on every access and never copied. */
    private Vec base;

    /**
     * Maps offsets past the original coefficients (and bias term), shifted to
     * start from zero, to the index of the first factor of the product stored at
     * that offset. Entry {@code i} holds the offset at which the products whose
     * first factor is {@code i + 1} begin. <br>
     * This is created lazily and rebuilt only when the length of the base vector
     * no longer matches. Call {@link #getReverseIndex() } to access this value
     */
    private int[] reverseIndex;

    /**
     * Creates a new vector that implicitly represents the degree 2 polynomial of
     * the base vector.
     *
     * @param base the base vector
     */
    public Poly2Vec(Vec base) {
        setBase(base);
    }

    /*
     * Some math needed for this class to make sense. Given an input we want the
     * poly 2 form plus a bias term. So for (x + y + z) we want (1 + x + y + z +
     * x^2 + x y + x z + y^2 + y z + z^2)
     *
     * Then for an input of size N, the poly 2 version has length (N+2)(N+1)/2
     *
     * The bias term and maintaining the original is easy. So lets assume we only
     * want to get the value for the x^2 term and after. IE: given a term x and y,
     * give me the index of the coeff that contains their product. Let x start from
     * 0 and let x^2 also start from zero, so we map from one space to the other.
     *
     * The exact index location, when x <= y, is then x N + y - x (x+1) / 2
     *
     */

    /**
     * Replaces the vector that this object expands. All values reported afterwards
     * are derived from the new base vector.
     *
     * @param base the new base vector
     */
    public void setBase(Vec base) {
        this.base = base;
    }

    /**
     * Returns the table of cumulative row offsets used to decode a product offset
     * back into the index of its first factor. The table depends only on the
     * length of the base vector, so it is computed once and reused until that
     * length changes. A table that has already been stored is never modified; a
     * new array is built and then swapped in.
     *
     * @return an array of {@code base.length()} strictly increasing offsets
     */
    private int[] getReverseIndex() {
        final int n = base.length();
        int[] table = reverseIndex;
        if (table == null || table.length != n) {
            table = new int[n];
            int rowStart = 0;
            for (int i = 0; i < n; i++) {
                rowStart += n - i;// row i holds the n - i products (i, i) ... (i, n - 1)
                table[i] = rowStart;
            }
            reverseIndex = table;
        }
        return table;
    }

    /**
     * Finds the index of the first factor for the product stored at the given
     * offset.
     *
     * @param pairOffset the position of the product, counted from the first
     *                   product term (i.e. the vector index minus
     *                   {@code base.length() + 1})
     * @return the base index of the first (smaller or equal) factor
     */
    private int firstFactorIndex(int pairOffset) {
        int pos = Arrays.binarySearch(getReverseIndex(), pairOffset);
        // exact hit: the offset is the first entry of row pos + 1
        // miss: the insertion point counts the row starts that precede the offset
        return pos < 0 ? -pos - 1 : pos + 1;
    }

    /**
     * Finds the index of the second factor for the product stored at the given
     * offset, by inverting {@code offset = x N + y - x (x + 1) / 2}.
     *
     * @param x          the index of the first factor, as returned by
     *                   {@link #firstFactorIndex(int)}
     * @param pairOffset the position of the product, counted from the first
     *                   product term
     * @return the base index of the second factor
     */
    private int secondFactorIndex(int x, int pairOffset) {
        // the first term is safe b/c it will always be an even number before division
        return (x * x + x) / 2 + pairOffset - base.length() * x;
    }

    /**
     * @return {@code (N + 2)(N + 1) / 2} where {@code N} is the length of the base
     *         vector
     */
    @Override
    public int length() {
        return (base.length() + 2) * (base.length() + 1) / 2;
    }

    /**
     * @return {@code (n + 2)(n + 1) / 2} where {@code n} is the value reported by
     *         the base vector's {@code nnz()}
     */
    @Override
    public int nnz() {
        return (base.nnz() + 2) * (base.nnz() + 1) / 2;
    }

    /**
     * Computes the value at the given position of the implicit vector. Index 0 is
     * the bias term (always 1), indices {@code 1 ... N} are the original values,
     * and every later index is the product of two base values.
     *
     * @param index the position to read
     * @return the value at that position
     * @throws IndexOutOfBoundsException if {@code index >= length()}
     */
    @Override
    public double get(int index) {
        if (index == 0)
            return 1;
        else if (index <= base.length())
            return base.get(index - 1);
        else if (index >= length())
            throw new IndexOutOfBoundsException("Vector is of length " + length() + ", but index " + index + " was requested");
        int pairOffset = index - base.length() - 1;
        int x = firstFactorIndex(pairOffset);
        double xVal = base.get(x);

        int y = secondFactorIndex(x, pairOffset);
        return xVal * base.get(y);
    }

    /**
     * Always fails, this vector is a read only view.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void set(int index, double val) {
        throw new UnsupportedOperationException("Poly2Vec may not be altered");
    }

    @Override
    public boolean isSparse() {
        return base.isSparse();
    }

    /**
     * @return a new Poly2Vec backed by a clone of the base vector
     */
    @Override
    public Vec clone() {
        return new Poly2Vec(base.clone());
    }

    /**
     * Creates an iterator that has nothing to return.
     *
     * @return an iterator whose {@code hasNext()} is always false
     */
    private static Iterator<IndexValue> emptyIterator() {
        return new Iterator<IndexValue>() {
            @Override
            public boolean hasNext() {
                return false;
            }

            @Override
            public IndexValue next() {
                throw new NoSuchElementException("Iterator is empty");
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException("Not supported yet.");
            }
        };
    }

    /**
     * Creates an iterator that returns only the bias term, index 0 with value 1.
     *
     * @return an iterator over exactly one element
     */
    private static Iterator<IndexValue> biasOnlyIterator() {
        return new Iterator<IndexValue>() {
            boolean hasNext = true;

            @Override
            public boolean hasNext() {
                return hasNext;
            }

            @Override
            public IndexValue next() {
                if (!hasNext)
                    throw new NoSuchElementException("Iterator is empty");
                hasNext = false;
                return new IndexValue(0, 1.0);
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException("Not supported yet.");
            }
        };
    }

    /**
     * Iterates over the non zero values of the implicit vector in increasing index
     * order, beginning at the first non zero value at or after {@code start}. The
     * iteration passes through up to three stages: the bias term, the original
     * values, and then the pairwise products. Except for the bias term, the
     * returned {@code IndexValue} object is reused between calls to
     * {@code next()}.
     *
     * @param start the first index of this vector that may be returned
     * @return the iterator, which does not support {@code remove()}
     */
    @Override
    public Iterator<IndexValue> getNonZeroIterator(int start) {
        // First case: empty base vector, the bias term is the only non zero value
        if (base.nnz() == 0)
            return start <= 0 ? biasOnlyIterator() : emptyIterator();// the bias sits at index 0, so it is skipped for any later start
        // Else, general case
        final int startStage;
        final Iterator<IndexValue> startOuterIter, startInerIter;
        boolean stage1Good = true;// fail occurs when the last index (or more) in the base vector is zero
        if (start == 0) {
            startStage = 0;
            startInerIter = startOuterIter = null;
        } else if (start <= base.length() && (stage1Good = base.getNonZeroIterator(start - 1).hasNext())) {
            startStage = 1;
            startOuterIter = base.getNonZeroIterator(start - 1);
            startInerIter = null;
        } else if (start >= length()) {
            startStage = 3;
            startInerIter = startOuterIter = null;
        } else// where do we start?
        {
            if (!stage1Good)
                start = base.length() + 1;// nothing left among the original values, begin at the first product
            Iterator<IndexValue> candidateOuterIter, candidateInerIter;
            start--;// laziness so we can update first thing in each iteration (we don't actually
                    // want to change the first value in the looping)
            do {
                start++;
                int pairOffset = start - base.length() - 1;
                int x = firstFactorIndex(pairOffset);
                int y = secondFactorIndex(x, pairOffset);
                candidateOuterIter = base.getNonZeroIterator(x);
                /*
                 * If the x coefficient is zero, we will jump to the next non zero x. This
                 * means y must change as well, so we will check if that has happened by
                 * grabbing another iterator to get the value. If this has happened, we know
                 * that y should be set to x's value
                 */
                int nextXIndex = candidateOuterIter.hasNext() ? base.getNonZeroIterator(x).next().getIndex() : -1;
                if (candidateOuterIter.hasNext() && nextXIndex > x)// x is at a zero, so we need the inner iter to go back to the "beginning"
                    candidateInerIter = base.getNonZeroIterator(nextXIndex);// next variable starts at val^2
                else
                    candidateInerIter = base.getNonZeroIterator(y);
            } while ((!candidateOuterIter.hasNext() || !candidateInerIter.hasNext()) && start < length());
            if (candidateOuterIter.hasNext() && candidateInerIter.hasNext() && start < length()) {
                startStage = 2;
                startOuterIter = candidateOuterIter;
                startInerIter = candidateInerIter;
            } else
                return emptyIterator();
        }

        return new Iterator<IndexValue>() {
            int stage = startStage;// 0 is for bias, 1 is for standard values, 2 is for combinations, 3 is for
                                   // empty

            Iterator<IndexValue> outerIter = startOuterIter, inerIter = startInerIter;
            // an inner iterator is only handed in when starting directly in stage 2
            IndexValue curOuterVal = inerIter != null ? outerIter.next() : null;
            IndexValue toReturn = new IndexValue(0, 0);

            @Override
            public boolean hasNext() {
                return stage < 3;
            }

            @Override
            public IndexValue next() {
                if (stage == 0) {
                    stage++;
                    outerIter = base.getNonZeroIterator();// we know its non empty b/c of first case
                    return new IndexValue(0, 1.0);
                } else if (stage == 1)// outerIter must always have a next item if stage = 1
                {
                    IndexValue iv = outerIter.next();
                    if (!outerIter.hasNext()) {
                        // original values are exhausted, set up the first product for the next call
                        stage++;
                        outerIter = base.getNonZeroIterator();
                        curOuterVal = outerIter.next();
                        inerIter = base.getNonZeroIterator();
                    }
                    toReturn.setIndex(1 + iv.getIndex());
                    toReturn.setValue(iv.getValue());
                    return toReturn;
                } else if (stage == 2) {
                    IndexValue innerVal = inerIter.next();
                    int x = curOuterVal.getIndex();
                    int y = innerVal.getIndex();
                    int N = base.length();
                    // skip the bias and the N original values, then apply x N + y - x (x+1) / 2
                    toReturn.setIndex(1 + N + x * N + y - x * (x + 1) / 2);
                    toReturn.setValue(curOuterVal.getValue() * innerVal.getValue());

                    if (!inerIter.hasNext()) {
                        if (!outerIter.hasNext())// we are out!
                        {
                            stage++;
                            outerIter = inerIter = null;
                        } else// Still at least one more round!
                        {
                            curOuterVal = outerIter.next();
                            // new inner iter starts at idx^2
                            inerIter = base.getNonZeroIterator(curOuterVal.getIndex());
                        }
                    }

                    return toReturn;
                } else // stage >= 3
                    throw new NoSuchElementException("Iterator is empty");
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException("Not supported yet.");
            }
        };
    }

    /**
     * Always fails, the length is determined entirely by the base vector.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setLength(int length) {
        throw new UnsupportedOperationException("Poly2Vec can't extend original base vector");
    }

}
